-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[ConstantTime][LLVM] Add llvm.ct.select intrinsic with generic SelectionDAG lowering #166702
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
This stack of pull requests is managed by Graphite. Learn more about stacking. |
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
0664c25 to
80a83ad
Compare
|
@llvm/pr-subscribers-llvm-ir @llvm/pr-subscribers-backend-x86 Author: Julius Alexandre (wizardengineer) ChangesHere's the updated PR description with the actual PR links: This is in reference to our [RFC] Constant-Time Coding Support proposal. SummaryThis PR introduces core infrastructure for constant-time selection operations in LLVM, providing a foundation for Changes1. Core Intrinsic DefinitionAdds the
2. Generic Fallback ImplementationImplements architecture-agnostic SelectionDAG lowering using bitwise operations:
The fallback implementation converts boolean conditions into bitmasks and uses bitwise arithmetic to achieve This approach guarantees:
3. Test CoverageIncludes basic test cases demonstrating fallback functionality:
These tests verify that the intrinsic correctly lowers to constant-time bitwise operations on architectures without Architecture SupportThis PR provides the fallback implementation that works on all architectures. Subsequent PRs in the stack will add:
Security PropertiesThe fallback implementation ensures:
Related PRsThis is part of a stacked PR series implementing constant-time selection:
TestingAll changes pass existing regression tests. New tests verify:
ReferencesPatch is 70.38 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/166702.diff 19 Files Affected:
diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index ff3dd0d4c3c51..656f6e718f029 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -783,6 +783,10 @@ enum NodeType {
/// i1 then the high bits must conform to getBooleanContents.
SELECT,
+ /// Constant-time Select, implemented with CMOV instruction. This is used to
+ /// implement constant-time select.
+ CTSELECT,
+
/// Select with a vector condition (op #0) and two vector operands (ops #1
/// and #2), returning a vector result. All vectors have the same length.
/// Much like the scalar select and setcc, each bit in the condition selects
diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index 1a5ffb38f2568..b5debd490d9cb 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -1352,6 +1352,13 @@ class SelectionDAG {
return getNode(Opcode, DL, VT, Cond, LHS, RHS, Flags);
}
+ SDValue getCTSelect(const SDLoc &DL, EVT VT, SDValue Cond, SDValue LHS,
+ SDValue RHS, SDNodeFlags Flags = SDNodeFlags()) {
+ assert(LHS.getValueType() == VT && RHS.getValueType() == VT &&
+ "Cannot use select on differing types");
+ return getNode(ISD::CTSELECT, DL, VT, Cond, LHS, RHS, Flags);
+ }
+
/// Helper function to make it easier to build SelectCC's if you just have an
/// ISD::CondCode instead of an SDValue.
SDValue getSelectCC(const SDLoc &DL, SDValue LHS, SDValue RHS, SDValue True,
diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
index 1759463ea7965..8e18eb2f7db0e 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -435,6 +435,9 @@ struct SDNodeFlags {
NonNeg | NoNaNs | NoInfs | SameSign | InBounds,
FastMathFlags = NoNaNs | NoInfs | NoSignedZeros | AllowReciprocal |
AllowContract | ApproximateFuncs | AllowReassociation,
+
+ // Flag for disabling optimization
+ NoMerge = 1 << 15,
};
/// Default constructor turns off all optimization flags.
@@ -486,7 +489,6 @@ struct SDNodeFlags {
bool hasNoFPExcept() const { return Flags & NoFPExcept; }
bool hasUnpredictable() const { return Flags & Unpredictable; }
bool hasInBounds() const { return Flags & InBounds; }
-
bool operator==(const SDNodeFlags &Other) const {
return Flags == Other.Flags;
}
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 78f63b4406eb0..8198485803d8b 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -242,11 +242,15 @@ class LLVM_ABI TargetLoweringBase {
/// Enum that describes what type of support for selects the target has.
enum SelectSupportKind {
- ScalarValSelect, // The target supports scalar selects (ex: cmov).
- ScalarCondVectorVal, // The target supports selects with a scalar condition
- // and vector values (ex: cmov).
- VectorMaskSelect // The target supports vector selects with a vector
- // mask (ex: x86 blends).
+ ScalarValSelect, // The target supports scalar selects (ex: cmov).
+ ScalarCondVectorVal, // The target supports selects with a scalar condition
+ // and vector values (ex: cmov).
+ VectorMaskSelect, // The target supports vector selects with a vector
+ // mask (ex: x86 blends).
+ CtSelect, // The target implements a custom constant-time select.
+ ScalarCondVectorValCtSelect, // The target supports selects with a scalar
+ // condition and vector values.
+ VectorMaskValCtSelect, // The target supports vector selects with a vector
};
/// Enum that specifies what an atomic load/AtomicRMWInst is expanded
@@ -476,8 +480,8 @@ class LLVM_ABI TargetLoweringBase {
MachineMemOperand::Flags
getVPIntrinsicMemOperandFlags(const VPIntrinsic &VPIntrin) const;
- virtual bool isSelectSupported(SelectSupportKind /*kind*/) const {
- return true;
+ virtual bool isSelectSupported(SelectSupportKind kind) const {
+ return kind != CtSelect;
}
/// Return true if the @llvm.get.active.lane.mask intrinsic should be expanded
diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td
index 6a079f62dd9cf..d41c61777089d 100644
--- a/llvm/include/llvm/IR/Intrinsics.td
+++ b/llvm/include/llvm/IR/Intrinsics.td
@@ -1825,6 +1825,15 @@ def int_coro_subfn_addr : DefaultAttrsIntrinsic<
[IntrReadMem, IntrArgMemOnly, ReadOnly<ArgIndex<0>>,
NoCapture<ArgIndex<0>>]>;
+///===-------------------------- Constant Time Intrinsics
+///--------------------------===//
+//
+// Intrinsic to support constant time select
+def int_ct_select
+ : DefaultAttrsIntrinsic<[llvm_any_ty],
+ [llvm_i1_ty, LLVMMatchType<0>, LLVMMatchType<0>],
+ [IntrWriteMem, IntrWillReturn, NoUndef<RetIndex>]>;
+
///===-------------------------- Other Intrinsics --------------------------===//
//
// TODO: We should introduce a new memory kind fo traps (and other side effects
diff --git a/llvm/include/llvm/Target/TargetSelectionDAG.td b/llvm/include/llvm/Target/TargetSelectionDAG.td
index 07a858fd682fc..de4abd713d3cf 100644
--- a/llvm/include/llvm/Target/TargetSelectionDAG.td
+++ b/llvm/include/llvm/Target/TargetSelectionDAG.td
@@ -214,6 +214,11 @@ def SDTSelect : SDTypeProfile<1, 3, [ // select
SDTCisInt<1>, SDTCisSameAs<0, 2>, SDTCisSameAs<2, 3>
]>;
+def SDTCtSelect
+ : SDTypeProfile<1, 3,
+ [ // ctselect
+ SDTCisInt<1>, SDTCisSameAs<0, 2>, SDTCisSameAs<2, 3>]>;
+
def SDTVSelect : SDTypeProfile<1, 3, [ // vselect
SDTCisVec<0>, SDTCisInt<1>, SDTCisSameAs<0, 2>, SDTCisSameAs<2, 3>, SDTCisSameNumEltsAs<0, 1>
]>;
@@ -717,6 +722,7 @@ def reset_fpmode : SDNode<"ISD::RESET_FPMODE", SDTNone, [SDNPHasChain]>;
def setcc : SDNode<"ISD::SETCC" , SDTSetCC>;
def select : SDNode<"ISD::SELECT" , SDTSelect>;
+def ctselect : SDNode<"ISD::CTSELECT", SDTCtSelect>;
def vselect : SDNode<"ISD::VSELECT" , SDTVSelect>;
def selectcc : SDNode<"ISD::SELECT_CC" , SDTSelectCC>;
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 46c4bb85a7420..28fcebbb4a92a 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -484,6 +484,7 @@ namespace {
SDValue visitCTTZ_ZERO_UNDEF(SDNode *N);
SDValue visitCTPOP(SDNode *N);
SDValue visitSELECT(SDNode *N);
+ SDValue visitCTSELECT(SDNode *N);
SDValue visitVSELECT(SDNode *N);
SDValue visitVP_SELECT(SDNode *N);
SDValue visitSELECT_CC(SDNode *N);
@@ -1898,6 +1899,7 @@ void DAGCombiner::Run(CombineLevel AtLevel) {
}
SDValue DAGCombiner::visit(SDNode *N) {
+
// clang-format off
switch (N->getOpcode()) {
default: break;
@@ -1968,6 +1970,7 @@ SDValue DAGCombiner::visit(SDNode *N) {
case ISD::CTTZ_ZERO_UNDEF: return visitCTTZ_ZERO_UNDEF(N);
case ISD::CTPOP: return visitCTPOP(N);
case ISD::SELECT: return visitSELECT(N);
+ case ISD::CTSELECT: return visitCTSELECT(N);
case ISD::VSELECT: return visitVSELECT(N);
case ISD::SELECT_CC: return visitSELECT_CC(N);
case ISD::SETCC: return visitSETCC(N);
@@ -6032,6 +6035,7 @@ static SDValue isSaturatingMinMax(SDValue N0, SDValue N1, SDValue N2,
N0CC = cast<CondCodeSDNode>(N0.getOperand(4))->get();
break;
case ISD::SELECT:
+ case ISD::CTSELECT:
case ISD::VSELECT:
if (N0.getOperand(0).getOpcode() != ISD::SETCC)
return SDValue();
@@ -12184,8 +12188,9 @@ template <class MatchContextClass>
static SDValue foldBoolSelectToLogic(SDNode *N, const SDLoc &DL,
SelectionDAG &DAG) {
assert((N->getOpcode() == ISD::SELECT || N->getOpcode() == ISD::VSELECT ||
- N->getOpcode() == ISD::VP_SELECT) &&
- "Expected a (v)(vp.)select");
+ N->getOpcode() == ISD::VP_SELECT ||
+ N->getOpcode() == ISD::CTSELECT) &&
+ "Expected a (v)(vp.)(ct) select");
SDValue Cond = N->getOperand(0);
SDValue T = N->getOperand(1), F = N->getOperand(2);
EVT VT = N->getValueType(0);
@@ -12547,6 +12552,109 @@ SDValue DAGCombiner::visitSELECT(SDNode *N) {
return SDValue();
}
+SDValue DAGCombiner::visitCTSELECT(SDNode *N) {
+ SDValue N0 = N->getOperand(0);
+ SDValue N1 = N->getOperand(1);
+ SDValue N2 = N->getOperand(2);
+ EVT VT = N->getValueType(0);
+ EVT VT0 = N0.getValueType();
+ SDLoc DL(N);
+ SDNodeFlags Flags = N->getFlags();
+
+ if (SDValue V = foldBoolSelectToLogic<EmptyMatchContext>(N, DL, DAG))
+ return V;
+
+ // ctselect (not Cond), N1, N2 -> ctselect Cond, N2, N1
+ if (SDValue F = extractBooleanFlip(N0, DAG, TLI, false)) {
+ SDValue SelectOp = DAG.getNode(ISD::CTSELECT, DL, VT, F, N2, N1);
+ SelectOp->setFlags(Flags);
+ return SelectOp;
+ }
+
+ if (VT0 == MVT::i1) {
+ // The code in this block deals with the following 2 equivalences:
+ // select(C0|C1, x, y) <=> select(C0, x, select(C1, x, y))
+ // select(C0&C1, x, y) <=> select(C0, select(C1, x, y), y)
+ // The target can specify its preferred form with the
+ // shouldNormalizeToSelectSequence() callback. However we always transform
+ // to the right anyway if we find the inner select exists in the DAG anyway
+ // and we always transform to the left side if we know that we can further
+ // optimize the combination of the conditions.
+ bool normalizeToSequence =
+ TLI.shouldNormalizeToSelectSequence(*DAG.getContext(), VT);
+ // ctselect (and Cond0, Cond1), X, Y
+ // -> ctselect Cond0, (ctselect Cond1, X, Y), Y
+ if (N0->getOpcode() == ISD::AND && N0->hasOneUse()) {
+ SDValue Cond0 = N0->getOperand(0);
+ SDValue Cond1 = N0->getOperand(1);
+ SDValue InnerSelect = DAG.getNode(ISD::CTSELECT, DL, N1.getValueType(),
+ Cond1, N1, N2, Flags);
+ if (normalizeToSequence || !InnerSelect.use_empty())
+ return DAG.getNode(ISD::CTSELECT, DL, N1.getValueType(), Cond0,
+ InnerSelect, N2, Flags);
+ // Cleanup on failure.
+ if (InnerSelect.use_empty())
+ recursivelyDeleteUnusedNodes(InnerSelect.getNode());
+ }
+ // ctselect (or Cond0, Cond1), X, Y -> ctselect Cond0, X, (ctselect Cond1,
+ // X, Y)
+ if (N0->getOpcode() == ISD::OR && N0->hasOneUse()) {
+ SDValue Cond0 = N0->getOperand(0);
+ SDValue Cond1 = N0->getOperand(1);
+ SDValue InnerSelect = DAG.getNode(ISD::CTSELECT, DL, N1.getValueType(),
+ Cond1, N1, N2, Flags);
+ if (normalizeToSequence || !InnerSelect.use_empty())
+ return DAG.getNode(ISD::CTSELECT, DL, N1.getValueType(), Cond0, N1,
+ InnerSelect, Flags);
+ // Cleanup on failure.
+ if (InnerSelect.use_empty())
+ recursivelyDeleteUnusedNodes(InnerSelect.getNode());
+ }
+
+ // ctselect Cond0, (ctselect Cond1, X, Y), Y -> ctselect (and Cond0, Cond1),
+ // X, Y
+ if (N1->getOpcode() == ISD::CTSELECT && N1->hasOneUse()) {
+ SDValue N1_0 = N1->getOperand(0);
+ SDValue N1_1 = N1->getOperand(1);
+ SDValue N1_2 = N1->getOperand(2);
+ if (N1_2 == N2 && N0.getValueType() == N1_0.getValueType()) {
+ // Create the actual and node if we can generate good code for it.
+ if (!normalizeToSequence) {
+ SDValue And = DAG.getNode(ISD::AND, DL, N0.getValueType(), N0, N1_0);
+ return DAG.getNode(ISD::CTSELECT, DL, N1.getValueType(), And, N1_1,
+ N2, Flags);
+ }
+ // Otherwise see if we can optimize the "and" to a better pattern.
+ if (SDValue Combined = visitANDLike(N0, N1_0, N)) {
+ return DAG.getNode(ISD::CTSELECT, DL, N1.getValueType(), Combined,
+ N1_1, N2, Flags);
+ }
+ }
+ }
+ // ctselect Cond0, X, (ctselect Cond1, X, Y) -> ctselect (or Cond0, Cond1),
+ // X, Y
+ if (N2->getOpcode() == ISD::CTSELECT && N2->hasOneUse()) {
+ SDValue N2_0 = N2->getOperand(0);
+ SDValue N2_1 = N2->getOperand(1);
+ SDValue N2_2 = N2->getOperand(2);
+ if (N2_1 == N1 && N0.getValueType() == N2_0.getValueType()) {
+ // Create the actual or node if we can generate good code for it.
+ if (!normalizeToSequence) {
+ SDValue Or = DAG.getNode(ISD::OR, DL, N0.getValueType(), N0, N2_0);
+ return DAG.getNode(ISD::CTSELECT, DL, N1.getValueType(), Or, N1, N2_2,
+ Flags);
+ }
+ // Otherwise see if we can optimize to a better pattern.
+ if (SDValue Combined = visitORLike(N0, N2_0, DL))
+ return DAG.getNode(ISD::CTSELECT, DL, N1.getValueType(), Combined, N1,
+ N2_2, Flags);
+ }
+ }
+ }
+
+ return SDValue();
+}
+
// This function assumes all the vselect's arguments are CONCAT_VECTOR
// nodes and that the condition is a BV of ConstantSDNodes (or undefs).
static SDValue ConvertSelectToConcatVector(SDNode *N, SelectionDAG &DAG) {
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index 431a81002074f..8178fd8981519 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -4136,6 +4136,46 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
}
Results.push_back(Tmp1);
break;
+ case ISD::CTSELECT: {
+ Tmp1 = Node->getOperand(0);
+ Tmp2 = Node->getOperand(1);
+ Tmp3 = Node->getOperand(2);
+ EVT VT = Tmp2.getValueType();
+ if (VT.isVector()) {
+ SmallVector<SDValue> Elements;
+ unsigned NumElements = VT.getVectorNumElements();
+ EVT ScalarVT = VT.getScalarType();
+ for (unsigned Idx = 0; Idx < NumElements; ++Idx) {
+ SDValue IdxVal = DAG.getConstant(Idx, dl, MVT::i64);
+ SDValue TVal =
+ DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, ScalarVT, Tmp2, IdxVal);
+ SDValue FVal =
+ DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, ScalarVT, Tmp3, IdxVal);
+ Elements.push_back(
+ DAG.getCTSelect(dl, ScalarVT, Tmp1, TVal, FVal, Node->getFlags()));
+ }
+ Tmp1 = DAG.getBuildVector(VT, dl, Elements);
+ } else if (VT.isFloatingPoint()) {
+ EVT IntegerVT = EVT::getIntegerVT(*DAG.getContext(), VT.getSizeInBits());
+ Tmp2 = DAG.getBitcast(IntegerVT, Tmp2);
+ Tmp3 = DAG.getBitcast(IntegerVT, Tmp3);
+ Tmp1 = DAG.getBitcast(VT, DAG.getCTSelect(dl, IntegerVT, Tmp1, Tmp2, Tmp3,
+ Node->getFlags()));
+ } else {
+ assert(VT.isInteger());
+ EVT HalfVT = VT.getHalfSizedIntegerVT(*DAG.getContext());
+ auto [Tmp2Lo, Tmp2Hi] = DAG.SplitScalar(Tmp2, dl, HalfVT, HalfVT);
+ auto [Tmp3Lo, Tmp3Hi] = DAG.SplitScalar(Tmp3, dl, HalfVT, HalfVT);
+ SDValue ResLo =
+ DAG.getCTSelect(dl, HalfVT, Tmp1, Tmp2Lo, Tmp3Lo, Node->getFlags());
+ SDValue ResHi =
+ DAG.getCTSelect(dl, HalfVT, Tmp1, Tmp2Hi, Tmp3Hi, Node->getFlags());
+ Tmp1 = DAG.getNode(ISD::BUILD_PAIR, dl, VT, ResLo, ResHi);
+ Tmp1->setFlags(Node->getFlags());
+ }
+ Results.push_back(Tmp1);
+ break;
+ }
case ISD::BR_JT: {
SDValue Chain = Node->getOperand(0);
SDValue Table = Node->getOperand(1);
@@ -5474,7 +5514,8 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
Results.push_back(DAG.getNode(ISD::TRUNCATE, dl, OVT, Tmp2));
break;
}
- case ISD::SELECT: {
+ case ISD::SELECT:
+ case ISD::CTSELECT: {
unsigned ExtOp, TruncOp;
if (Node->getValueType(0).isVector() ||
Node->getValueType(0).getSizeInBits() == NVT.getSizeInBits()) {
@@ -5492,7 +5533,8 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
Tmp2 = DAG.getNode(ExtOp, dl, NVT, Node->getOperand(1));
Tmp3 = DAG.getNode(ExtOp, dl, NVT, Node->getOperand(2));
// Perform the larger operation, then round down.
- Tmp1 = DAG.getSelect(dl, NVT, Tmp1, Tmp2, Tmp3);
+ Tmp1 = DAG.getNode(Node->getOpcode(), dl, NVT, Tmp1, Tmp2, Tmp3);
+ Tmp1->setFlags(Node->getFlags());
if (TruncOp != ISD::FP_ROUND)
Tmp1 = DAG.getNode(TruncOp, dl, Node->getValueType(0), Tmp1);
else
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
index 58983cb57d7f6..855a15a744cfe 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
@@ -159,6 +159,7 @@ void DAGTypeLegalizer::SoftenFloatResult(SDNode *N, unsigned ResNo) {
case ISD::ATOMIC_LOAD: R = SoftenFloatRes_ATOMIC_LOAD(N); break;
case ISD::ATOMIC_SWAP: R = BitcastToInt_ATOMIC_SWAP(N); break;
case ISD::SELECT: R = SoftenFloatRes_SELECT(N); break;
+ case ISD::CTSELECT: R = SoftenFloatRes_CTSELECT(N); break;
case ISD::SELECT_CC: R = SoftenFloatRes_SELECT_CC(N); break;
case ISD::FREEZE: R = SoftenFloatRes_FREEZE(N); break;
case ISD::STRICT_SINT_TO_FP:
@@ -1041,6 +1042,13 @@ SDValue DAGTypeLegalizer::SoftenFloatRes_SELECT(SDNode *N) {
LHS.getValueType(), N->getOperand(0), LHS, RHS);
}
+SDValue DAGTypeLegalizer::SoftenFloatRes_CTSELECT(SDNode *N) {
+ SDValue LHS = GetSoftenedFloat(N->getOperand(1));
+ SDValue RHS = GetSoftenedFloat(N->getOperand(2));
+ return DAG.getCTSelect(SDLoc(N), LHS.getValueType(), N->getOperand(0), LHS,
+ RHS);
+}
+
SDValue DAGTypeLegalizer::SoftenFloatRes_SELECT_CC(SDNode *N) {
SDValue LHS = GetSoftenedFloat(N->getOperand(2));
SDValue RHS = GetSoftenedFloat(N->getOperand(3));
@@ -1561,6 +1569,7 @@ void DAGTypeLegalizer::ExpandFloatResult(SDNode *N, unsigned ResNo) {
case ISD::POISON:
case ISD::UNDEF: SplitRes_UNDEF(N, Lo, Hi); break;
case ISD::SELECT: SplitRes_Select(N, Lo, Hi); break;
+ case ISD::CTSELECT: SplitRes_Select(N, Lo, Hi); break;
case ISD::SELECT_CC: SplitRes_SELECT_CC(N, Lo, Hi); break;
case ISD::MERGE_VALUES: ExpandRes_MERGE_VALUES(N, ResNo, Lo, Hi); break;
@@ -2917,6 +2926,9 @@ void DAGTypeLegalizer::PromoteFloatResult(SDNode *N, unsigned ResNo) {
R = PromoteFloatRes_ATOMIC_LOAD(N);
break;
case ISD::SELECT: R = PromoteFloatRes_SELECT(N); break;
+ case ISD::CTSELECT:
+ R = PromoteFloatRes_SELECT(N);
+ break;
case ISD::SELECT_CC: R = PromoteFloatRes_SELECT_CC(N); break;
case ISD::SINT_TO_FP:
@@ -3219,7 +3231,7 @@ SDValue DAGTypeLegalizer::PromoteFloatRes_SELECT(SDNode *N) {
SDValue TrueVal = GetPromotedFloat(N->getOperand(1));
SDValue FalseVal = GetPromotedFloat(N->getOperand(2));
- return DAG.getNode(ISD::SELECT, SDLoc(N), TrueVal->getValueType(0),
+ return DAG.getNode(N->getOpcode(), SDLoc(N), TrueVal->getValueType(0),
N->getOperand(0), TrueVal, FalseVal);
}
@@ -3403,6 +3415,9 @@ void DAGTypeLegalizer::SoftPromoteHalfResult(SDNode *N, unsigned ResNo) {
R = SoftPromoteHalfRes_ATOMIC_LOAD(N);
break;
case ISD::SELECT: R = SoftPromoteHalfRes_SELECT(N); break;
+ case ISD::CTSELECT:
+ R = SoftPromoteHalfRes_SELECT(N);
+ break;
case ISD::SELECT_CC: R = SoftPromoteHalfRes_SELECT_CC(N); break;
case ISD::STRICT_SINT_TO_FP:
case ISD::STRICT_UINT_TO_FP:
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index 44e5a187c4281..0135b3195438b 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -95,6 +95,7 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
Res = PromoteIntRes_VECTOR_COMPRESS(N);
break;
case ISD::SELECT:
+ case ISD::CTSELECT:
case ISD::VSELECT:
case ISD::VP_SELECT:
case ISD::VP_MERGE:
@@ -2013,6 +2014,9 @@...
[truncated]
|
|
@llvm/pr-subscribers-llvm-selectiondag Author: Julius Alexandre (wizardengineer) ChangesHere's the updated PR description with the actual PR links: This is in reference to our [RFC] Constant-Time Coding Support proposal. SummaryThis PR introduces core infrastructure for constant-time selection operations in LLVM, providing a foundation for Changes1. Core Intrinsic DefinitionAdds the
2. Generic Fallback ImplementationImplements architecture-agnostic SelectionDAG lowering using bitwise operations:
The fallback implementation converts boolean conditions into bitmasks and uses bitwise arithmetic to achieve This approach guarantees:
3. Test CoverageIncludes basic test cases demonstrating fallback functionality:
These tests verify that the intrinsic correctly lowers to constant-time bitwise operations on architectures without Architecture SupportThis PR provides the fallback implementation that works on all architectures. Subsequent PRs in the stack will add:
Security PropertiesThe fallback implementation ensures:
Related PRsThis is part of a stacked PR series implementing constant-time selection:
TestingAll changes pass existing regression tests. New tests verify:
ReferencesPatch is 70.38 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/166702.diff 19 Files Affected:
diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index ff3dd0d4c3c51..656f6e718f029 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -783,6 +783,10 @@ enum NodeType {
/// i1 then the high bits must conform to getBooleanContents.
SELECT,
+ /// Constant-time Select, implemented with CMOV instruction. This is used to
+ /// implement constant-time select.
+ CTSELECT,
+
/// Select with a vector condition (op #0) and two vector operands (ops #1
/// and #2), returning a vector result. All vectors have the same length.
/// Much like the scalar select and setcc, each bit in the condition selects
diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index 1a5ffb38f2568..b5debd490d9cb 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -1352,6 +1352,13 @@ class SelectionDAG {
return getNode(Opcode, DL, VT, Cond, LHS, RHS, Flags);
}
+ SDValue getCTSelect(const SDLoc &DL, EVT VT, SDValue Cond, SDValue LHS,
+ SDValue RHS, SDNodeFlags Flags = SDNodeFlags()) {
+ assert(LHS.getValueType() == VT && RHS.getValueType() == VT &&
+ "Cannot use select on differing types");
+ return getNode(ISD::CTSELECT, DL, VT, Cond, LHS, RHS, Flags);
+ }
+
/// Helper function to make it easier to build SelectCC's if you just have an
/// ISD::CondCode instead of an SDValue.
SDValue getSelectCC(const SDLoc &DL, SDValue LHS, SDValue RHS, SDValue True,
diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
index 1759463ea7965..8e18eb2f7db0e 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -435,6 +435,9 @@ struct SDNodeFlags {
NonNeg | NoNaNs | NoInfs | SameSign | InBounds,
FastMathFlags = NoNaNs | NoInfs | NoSignedZeros | AllowReciprocal |
AllowContract | ApproximateFuncs | AllowReassociation,
+
+ // Flag for disabling optimization
+ NoMerge = 1 << 15,
};
/// Default constructor turns off all optimization flags.
@@ -486,7 +489,6 @@ struct SDNodeFlags {
bool hasNoFPExcept() const { return Flags & NoFPExcept; }
bool hasUnpredictable() const { return Flags & Unpredictable; }
bool hasInBounds() const { return Flags & InBounds; }
-
bool operator==(const SDNodeFlags &Other) const {
return Flags == Other.Flags;
}
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 78f63b4406eb0..8198485803d8b 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -242,11 +242,15 @@ class LLVM_ABI TargetLoweringBase {
/// Enum that describes what type of support for selects the target has.
enum SelectSupportKind {
- ScalarValSelect, // The target supports scalar selects (ex: cmov).
- ScalarCondVectorVal, // The target supports selects with a scalar condition
- // and vector values (ex: cmov).
- VectorMaskSelect // The target supports vector selects with a vector
- // mask (ex: x86 blends).
+ ScalarValSelect, // The target supports scalar selects (ex: cmov).
+ ScalarCondVectorVal, // The target supports selects with a scalar condition
+ // and vector values (ex: cmov).
+ VectorMaskSelect, // The target supports vector selects with a vector
+ // mask (ex: x86 blends).
+ CtSelect, // The target implements a custom constant-time select.
+ ScalarCondVectorValCtSelect, // The target supports selects with a scalar
+ // condition and vector values.
+ VectorMaskValCtSelect, // The target supports vector selects with a vector
};
/// Enum that specifies what an atomic load/AtomicRMWInst is expanded
@@ -476,8 +480,8 @@ class LLVM_ABI TargetLoweringBase {
MachineMemOperand::Flags
getVPIntrinsicMemOperandFlags(const VPIntrinsic &VPIntrin) const;
- virtual bool isSelectSupported(SelectSupportKind /*kind*/) const {
- return true;
+ virtual bool isSelectSupported(SelectSupportKind kind) const {
+ return kind != CtSelect;
}
/// Return true if the @llvm.get.active.lane.mask intrinsic should be expanded
diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td
index 6a079f62dd9cf..d41c61777089d 100644
--- a/llvm/include/llvm/IR/Intrinsics.td
+++ b/llvm/include/llvm/IR/Intrinsics.td
@@ -1825,6 +1825,15 @@ def int_coro_subfn_addr : DefaultAttrsIntrinsic<
[IntrReadMem, IntrArgMemOnly, ReadOnly<ArgIndex<0>>,
NoCapture<ArgIndex<0>>]>;
+///===-------------------------- Constant Time Intrinsics
+///--------------------------===//
+//
+// Intrinsic to support constant time select
+def int_ct_select
+ : DefaultAttrsIntrinsic<[llvm_any_ty],
+ [llvm_i1_ty, LLVMMatchType<0>, LLVMMatchType<0>],
+ [IntrWriteMem, IntrWillReturn, NoUndef<RetIndex>]>;
+
///===-------------------------- Other Intrinsics --------------------------===//
//
// TODO: We should introduce a new memory kind fo traps (and other side effects
diff --git a/llvm/include/llvm/Target/TargetSelectionDAG.td b/llvm/include/llvm/Target/TargetSelectionDAG.td
index 07a858fd682fc..de4abd713d3cf 100644
--- a/llvm/include/llvm/Target/TargetSelectionDAG.td
+++ b/llvm/include/llvm/Target/TargetSelectionDAG.td
@@ -214,6 +214,11 @@ def SDTSelect : SDTypeProfile<1, 3, [ // select
SDTCisInt<1>, SDTCisSameAs<0, 2>, SDTCisSameAs<2, 3>
]>;
+def SDTCtSelect
+ : SDTypeProfile<1, 3,
+ [ // ctselect
+ SDTCisInt<1>, SDTCisSameAs<0, 2>, SDTCisSameAs<2, 3>]>;
+
def SDTVSelect : SDTypeProfile<1, 3, [ // vselect
SDTCisVec<0>, SDTCisInt<1>, SDTCisSameAs<0, 2>, SDTCisSameAs<2, 3>, SDTCisSameNumEltsAs<0, 1>
]>;
@@ -717,6 +722,7 @@ def reset_fpmode : SDNode<"ISD::RESET_FPMODE", SDTNone, [SDNPHasChain]>;
def setcc : SDNode<"ISD::SETCC" , SDTSetCC>;
def select : SDNode<"ISD::SELECT" , SDTSelect>;
+def ctselect : SDNode<"ISD::CTSELECT", SDTCtSelect>;
def vselect : SDNode<"ISD::VSELECT" , SDTVSelect>;
def selectcc : SDNode<"ISD::SELECT_CC" , SDTSelectCC>;
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 46c4bb85a7420..28fcebbb4a92a 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -484,6 +484,7 @@ namespace {
SDValue visitCTTZ_ZERO_UNDEF(SDNode *N);
SDValue visitCTPOP(SDNode *N);
SDValue visitSELECT(SDNode *N);
+ SDValue visitCTSELECT(SDNode *N);
SDValue visitVSELECT(SDNode *N);
SDValue visitVP_SELECT(SDNode *N);
SDValue visitSELECT_CC(SDNode *N);
@@ -1898,6 +1899,7 @@ void DAGCombiner::Run(CombineLevel AtLevel) {
}
SDValue DAGCombiner::visit(SDNode *N) {
+
// clang-format off
switch (N->getOpcode()) {
default: break;
@@ -1968,6 +1970,7 @@ SDValue DAGCombiner::visit(SDNode *N) {
case ISD::CTTZ_ZERO_UNDEF: return visitCTTZ_ZERO_UNDEF(N);
case ISD::CTPOP: return visitCTPOP(N);
case ISD::SELECT: return visitSELECT(N);
+ case ISD::CTSELECT: return visitCTSELECT(N);
case ISD::VSELECT: return visitVSELECT(N);
case ISD::SELECT_CC: return visitSELECT_CC(N);
case ISD::SETCC: return visitSETCC(N);
@@ -6032,6 +6035,7 @@ static SDValue isSaturatingMinMax(SDValue N0, SDValue N1, SDValue N2,
N0CC = cast<CondCodeSDNode>(N0.getOperand(4))->get();
break;
case ISD::SELECT:
+ case ISD::CTSELECT:
case ISD::VSELECT:
if (N0.getOperand(0).getOpcode() != ISD::SETCC)
return SDValue();
@@ -12184,8 +12188,9 @@ template <class MatchContextClass>
static SDValue foldBoolSelectToLogic(SDNode *N, const SDLoc &DL,
SelectionDAG &DAG) {
assert((N->getOpcode() == ISD::SELECT || N->getOpcode() == ISD::VSELECT ||
- N->getOpcode() == ISD::VP_SELECT) &&
- "Expected a (v)(vp.)select");
+ N->getOpcode() == ISD::VP_SELECT ||
+ N->getOpcode() == ISD::CTSELECT) &&
+ "Expected a (v)(vp.)(ct) select");
SDValue Cond = N->getOperand(0);
SDValue T = N->getOperand(1), F = N->getOperand(2);
EVT VT = N->getValueType(0);
@@ -12547,6 +12552,109 @@ SDValue DAGCombiner::visitSELECT(SDNode *N) {
return SDValue();
}
+SDValue DAGCombiner::visitCTSELECT(SDNode *N) {
+ SDValue N0 = N->getOperand(0);
+ SDValue N1 = N->getOperand(1);
+ SDValue N2 = N->getOperand(2);
+ EVT VT = N->getValueType(0);
+ EVT VT0 = N0.getValueType();
+ SDLoc DL(N);
+ SDNodeFlags Flags = N->getFlags();
+
+ if (SDValue V = foldBoolSelectToLogic<EmptyMatchContext>(N, DL, DAG))
+ return V;
+
+ // ctselect (not Cond), N1, N2 -> ctselect Cond, N2, N1
+ if (SDValue F = extractBooleanFlip(N0, DAG, TLI, false)) {
+ SDValue SelectOp = DAG.getNode(ISD::CTSELECT, DL, VT, F, N2, N1);
+ SelectOp->setFlags(Flags);
+ return SelectOp;
+ }
+
+ if (VT0 == MVT::i1) {
+ // The code in this block deals with the following 2 equivalences:
+ // select(C0|C1, x, y) <=> select(C0, x, select(C1, x, y))
+ // select(C0&C1, x, y) <=> select(C0, select(C1, x, y), y)
+ // The target can specify its preferred form with the
+ // shouldNormalizeToSelectSequence() callback. However we always transform
+ // to the right anyway if we find the inner select exists in the DAG anyway
+ // and we always transform to the left side if we know that we can further
+ // optimize the combination of the conditions.
+ bool normalizeToSequence =
+ TLI.shouldNormalizeToSelectSequence(*DAG.getContext(), VT);
+ // ctselect (and Cond0, Cond1), X, Y
+ // -> ctselect Cond0, (ctselect Cond1, X, Y), Y
+ if (N0->getOpcode() == ISD::AND && N0->hasOneUse()) {
+ SDValue Cond0 = N0->getOperand(0);
+ SDValue Cond1 = N0->getOperand(1);
+ SDValue InnerSelect = DAG.getNode(ISD::CTSELECT, DL, N1.getValueType(),
+ Cond1, N1, N2, Flags);
+ if (normalizeToSequence || !InnerSelect.use_empty())
+ return DAG.getNode(ISD::CTSELECT, DL, N1.getValueType(), Cond0,
+ InnerSelect, N2, Flags);
+ // Cleanup on failure.
+ if (InnerSelect.use_empty())
+ recursivelyDeleteUnusedNodes(InnerSelect.getNode());
+ }
+ // ctselect (or Cond0, Cond1), X, Y -> ctselect Cond0, X, (ctselect Cond1,
+ // X, Y)
+ if (N0->getOpcode() == ISD::OR && N0->hasOneUse()) {
+ SDValue Cond0 = N0->getOperand(0);
+ SDValue Cond1 = N0->getOperand(1);
+ SDValue InnerSelect = DAG.getNode(ISD::CTSELECT, DL, N1.getValueType(),
+ Cond1, N1, N2, Flags);
+ if (normalizeToSequence || !InnerSelect.use_empty())
+ return DAG.getNode(ISD::CTSELECT, DL, N1.getValueType(), Cond0, N1,
+ InnerSelect, Flags);
+ // Cleanup on failure.
+ if (InnerSelect.use_empty())
+ recursivelyDeleteUnusedNodes(InnerSelect.getNode());
+ }
+
+ // ctselect Cond0, (ctselect Cond1, X, Y), Y -> ctselect (and Cond0, Cond1),
+ // X, Y
+ if (N1->getOpcode() == ISD::CTSELECT && N1->hasOneUse()) {
+ SDValue N1_0 = N1->getOperand(0);
+ SDValue N1_1 = N1->getOperand(1);
+ SDValue N1_2 = N1->getOperand(2);
+ if (N1_2 == N2 && N0.getValueType() == N1_0.getValueType()) {
+ // Create the actual and node if we can generate good code for it.
+ if (!normalizeToSequence) {
+ SDValue And = DAG.getNode(ISD::AND, DL, N0.getValueType(), N0, N1_0);
+ return DAG.getNode(ISD::CTSELECT, DL, N1.getValueType(), And, N1_1,
+ N2, Flags);
+ }
+ // Otherwise see if we can optimize the "and" to a better pattern.
+ if (SDValue Combined = visitANDLike(N0, N1_0, N)) {
+ return DAG.getNode(ISD::CTSELECT, DL, N1.getValueType(), Combined,
+ N1_1, N2, Flags);
+ }
+ }
+ }
+ // ctselect Cond0, X, (ctselect Cond1, X, Y) -> ctselect (or Cond0, Cond1),
+ // X, Y
+ if (N2->getOpcode() == ISD::CTSELECT && N2->hasOneUse()) {
+ SDValue N2_0 = N2->getOperand(0);
+ SDValue N2_1 = N2->getOperand(1);
+ SDValue N2_2 = N2->getOperand(2);
+ if (N2_1 == N1 && N0.getValueType() == N2_0.getValueType()) {
+ // Create the actual or node if we can generate good code for it.
+ if (!normalizeToSequence) {
+ SDValue Or = DAG.getNode(ISD::OR, DL, N0.getValueType(), N0, N2_0);
+ return DAG.getNode(ISD::CTSELECT, DL, N1.getValueType(), Or, N1, N2_2,
+ Flags);
+ }
+ // Otherwise see if we can optimize to a better pattern.
+ if (SDValue Combined = visitORLike(N0, N2_0, DL))
+ return DAG.getNode(ISD::CTSELECT, DL, N1.getValueType(), Combined, N1,
+ N2_2, Flags);
+ }
+ }
+ }
+
+ return SDValue();
+}
+
// This function assumes all the vselect's arguments are CONCAT_VECTOR
// nodes and that the condition is a BV of ConstantSDNodes (or undefs).
static SDValue ConvertSelectToConcatVector(SDNode *N, SelectionDAG &DAG) {
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index 431a81002074f..8178fd8981519 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -4136,6 +4136,46 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
}
Results.push_back(Tmp1);
break;
+ case ISD::CTSELECT: {
+ Tmp1 = Node->getOperand(0);
+ Tmp2 = Node->getOperand(1);
+ Tmp3 = Node->getOperand(2);
+ EVT VT = Tmp2.getValueType();
+ if (VT.isVector()) {
+ SmallVector<SDValue> Elements;
+ unsigned NumElements = VT.getVectorNumElements();
+ EVT ScalarVT = VT.getScalarType();
+ for (unsigned Idx = 0; Idx < NumElements; ++Idx) {
+ SDValue IdxVal = DAG.getConstant(Idx, dl, MVT::i64);
+ SDValue TVal =
+ DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, ScalarVT, Tmp2, IdxVal);
+ SDValue FVal =
+ DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, ScalarVT, Tmp3, IdxVal);
+ Elements.push_back(
+ DAG.getCTSelect(dl, ScalarVT, Tmp1, TVal, FVal, Node->getFlags()));
+ }
+ Tmp1 = DAG.getBuildVector(VT, dl, Elements);
+ } else if (VT.isFloatingPoint()) {
+ EVT IntegerVT = EVT::getIntegerVT(*DAG.getContext(), VT.getSizeInBits());
+ Tmp2 = DAG.getBitcast(IntegerVT, Tmp2);
+ Tmp3 = DAG.getBitcast(IntegerVT, Tmp3);
+ Tmp1 = DAG.getBitcast(VT, DAG.getCTSelect(dl, IntegerVT, Tmp1, Tmp2, Tmp3,
+ Node->getFlags()));
+ } else {
+ assert(VT.isInteger());
+ EVT HalfVT = VT.getHalfSizedIntegerVT(*DAG.getContext());
+ auto [Tmp2Lo, Tmp2Hi] = DAG.SplitScalar(Tmp2, dl, HalfVT, HalfVT);
+ auto [Tmp3Lo, Tmp3Hi] = DAG.SplitScalar(Tmp3, dl, HalfVT, HalfVT);
+ SDValue ResLo =
+ DAG.getCTSelect(dl, HalfVT, Tmp1, Tmp2Lo, Tmp3Lo, Node->getFlags());
+ SDValue ResHi =
+ DAG.getCTSelect(dl, HalfVT, Tmp1, Tmp2Hi, Tmp3Hi, Node->getFlags());
+ Tmp1 = DAG.getNode(ISD::BUILD_PAIR, dl, VT, ResLo, ResHi);
+ Tmp1->setFlags(Node->getFlags());
+ }
+ Results.push_back(Tmp1);
+ break;
+ }
case ISD::BR_JT: {
SDValue Chain = Node->getOperand(0);
SDValue Table = Node->getOperand(1);
@@ -5474,7 +5514,8 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
Results.push_back(DAG.getNode(ISD::TRUNCATE, dl, OVT, Tmp2));
break;
}
- case ISD::SELECT: {
+ case ISD::SELECT:
+ case ISD::CTSELECT: {
unsigned ExtOp, TruncOp;
if (Node->getValueType(0).isVector() ||
Node->getValueType(0).getSizeInBits() == NVT.getSizeInBits()) {
@@ -5492,7 +5533,8 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
Tmp2 = DAG.getNode(ExtOp, dl, NVT, Node->getOperand(1));
Tmp3 = DAG.getNode(ExtOp, dl, NVT, Node->getOperand(2));
// Perform the larger operation, then round down.
- Tmp1 = DAG.getSelect(dl, NVT, Tmp1, Tmp2, Tmp3);
+ Tmp1 = DAG.getNode(Node->getOpcode(), dl, NVT, Tmp1, Tmp2, Tmp3);
+ Tmp1->setFlags(Node->getFlags());
if (TruncOp != ISD::FP_ROUND)
Tmp1 = DAG.getNode(TruncOp, dl, Node->getValueType(0), Tmp1);
else
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
index 58983cb57d7f6..855a15a744cfe 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
@@ -159,6 +159,7 @@ void DAGTypeLegalizer::SoftenFloatResult(SDNode *N, unsigned ResNo) {
case ISD::ATOMIC_LOAD: R = SoftenFloatRes_ATOMIC_LOAD(N); break;
case ISD::ATOMIC_SWAP: R = BitcastToInt_ATOMIC_SWAP(N); break;
case ISD::SELECT: R = SoftenFloatRes_SELECT(N); break;
+ case ISD::CTSELECT: R = SoftenFloatRes_CTSELECT(N); break;
case ISD::SELECT_CC: R = SoftenFloatRes_SELECT_CC(N); break;
case ISD::FREEZE: R = SoftenFloatRes_FREEZE(N); break;
case ISD::STRICT_SINT_TO_FP:
@@ -1041,6 +1042,13 @@ SDValue DAGTypeLegalizer::SoftenFloatRes_SELECT(SDNode *N) {
LHS.getValueType(), N->getOperand(0), LHS, RHS);
}
+SDValue DAGTypeLegalizer::SoftenFloatRes_CTSELECT(SDNode *N) {
+ SDValue LHS = GetSoftenedFloat(N->getOperand(1));
+ SDValue RHS = GetSoftenedFloat(N->getOperand(2));
+ return DAG.getCTSelect(SDLoc(N), LHS.getValueType(), N->getOperand(0), LHS,
+ RHS);
+}
+
SDValue DAGTypeLegalizer::SoftenFloatRes_SELECT_CC(SDNode *N) {
SDValue LHS = GetSoftenedFloat(N->getOperand(2));
SDValue RHS = GetSoftenedFloat(N->getOperand(3));
@@ -1561,6 +1569,7 @@ void DAGTypeLegalizer::ExpandFloatResult(SDNode *N, unsigned ResNo) {
case ISD::POISON:
case ISD::UNDEF: SplitRes_UNDEF(N, Lo, Hi); break;
case ISD::SELECT: SplitRes_Select(N, Lo, Hi); break;
+ case ISD::CTSELECT: SplitRes_Select(N, Lo, Hi); break;
case ISD::SELECT_CC: SplitRes_SELECT_CC(N, Lo, Hi); break;
case ISD::MERGE_VALUES: ExpandRes_MERGE_VALUES(N, ResNo, Lo, Hi); break;
@@ -2917,6 +2926,9 @@ void DAGTypeLegalizer::PromoteFloatResult(SDNode *N, unsigned ResNo) {
R = PromoteFloatRes_ATOMIC_LOAD(N);
break;
case ISD::SELECT: R = PromoteFloatRes_SELECT(N); break;
+ case ISD::CTSELECT:
+ R = PromoteFloatRes_SELECT(N);
+ break;
case ISD::SELECT_CC: R = PromoteFloatRes_SELECT_CC(N); break;
case ISD::SINT_TO_FP:
@@ -3219,7 +3231,7 @@ SDValue DAGTypeLegalizer::PromoteFloatRes_SELECT(SDNode *N) {
SDValue TrueVal = GetPromotedFloat(N->getOperand(1));
SDValue FalseVal = GetPromotedFloat(N->getOperand(2));
- return DAG.getNode(ISD::SELECT, SDLoc(N), TrueVal->getValueType(0),
+ return DAG.getNode(N->getOpcode(), SDLoc(N), TrueVal->getValueType(0),
N->getOperand(0), TrueVal, FalseVal);
}
@@ -3403,6 +3415,9 @@ void DAGTypeLegalizer::SoftPromoteHalfResult(SDNode *N, unsigned ResNo) {
R = SoftPromoteHalfRes_ATOMIC_LOAD(N);
break;
case ISD::SELECT: R = SoftPromoteHalfRes_SELECT(N); break;
+ case ISD::CTSELECT:
+ R = SoftPromoteHalfRes_SELECT(N);
+ break;
case ISD::SELECT_CC: R = SoftPromoteHalfRes_SELECT_CC(N); break;
case ISD::STRICT_SINT_TO_FP:
case ISD::STRICT_UINT_TO_FP:
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index 44e5a187c4281..0135b3195438b 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -95,6 +95,7 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
Res = PromoteIntRes_VECTOR_COMPRESS(N);
break;
case ISD::SELECT:
+ case ISD::CTSELECT:
case ISD::VSELECT:
case ISD::VP_SELECT:
case ISD::VP_MERGE:
@@ -2013,6 +2014,9 @@...
[truncated]
|
|
@llvm/pr-subscribers-backend-risc-v Author: Julius Alexandre (wizardengineer) ChangesHere's the updated PR description with the actual PR links: This is in reference to our [RFC] Constant-Time Coding Support proposal. SummaryThis PR introduces core infrastructure for constant-time selection operations in LLVM, providing a foundation for Changes1. Core Intrinsic DefinitionAdds the
2. Generic Fallback ImplementationImplements architecture-agnostic SelectionDAG lowering using bitwise operations:
The fallback implementation converts boolean conditions into bitmasks and uses bitwise arithmetic to achieve This approach guarantees:
3. Test CoverageIncludes basic test cases demonstrating fallback functionality:
These tests verify that the intrinsic correctly lowers to constant-time bitwise operations on architectures without Architecture SupportThis PR provides the fallback implementation that works on all architectures. Subsequent PRs in the stack will add:
Security PropertiesThe fallback implementation ensures:
Related PRsThis is part of a stacked PR series implementing constant-time selection:
TestingAll changes pass existing regression tests. New tests verify:
ReferencesPatch is 70.38 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/166702.diff 19 Files Affected:
diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index ff3dd0d4c3c51..656f6e718f029 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -783,6 +783,10 @@ enum NodeType {
/// i1 then the high bits must conform to getBooleanContents.
SELECT,
+ /// Constant-time Select, implemented with CMOV instruction. This is used to
+ /// implement constant-time select.
+ CTSELECT,
+
/// Select with a vector condition (op #0) and two vector operands (ops #1
/// and #2), returning a vector result. All vectors have the same length.
/// Much like the scalar select and setcc, each bit in the condition selects
diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index 1a5ffb38f2568..b5debd490d9cb 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -1352,6 +1352,13 @@ class SelectionDAG {
return getNode(Opcode, DL, VT, Cond, LHS, RHS, Flags);
}
+ SDValue getCTSelect(const SDLoc &DL, EVT VT, SDValue Cond, SDValue LHS,
+ SDValue RHS, SDNodeFlags Flags = SDNodeFlags()) {
+ assert(LHS.getValueType() == VT && RHS.getValueType() == VT &&
+ "Cannot use select on differing types");
+ return getNode(ISD::CTSELECT, DL, VT, Cond, LHS, RHS, Flags);
+ }
+
/// Helper function to make it easier to build SelectCC's if you just have an
/// ISD::CondCode instead of an SDValue.
SDValue getSelectCC(const SDLoc &DL, SDValue LHS, SDValue RHS, SDValue True,
diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
index 1759463ea7965..8e18eb2f7db0e 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -435,6 +435,9 @@ struct SDNodeFlags {
NonNeg | NoNaNs | NoInfs | SameSign | InBounds,
FastMathFlags = NoNaNs | NoInfs | NoSignedZeros | AllowReciprocal |
AllowContract | ApproximateFuncs | AllowReassociation,
+
+ // Flag for disabling optimization
+ NoMerge = 1 << 15,
};
/// Default constructor turns off all optimization flags.
@@ -486,7 +489,6 @@ struct SDNodeFlags {
bool hasNoFPExcept() const { return Flags & NoFPExcept; }
bool hasUnpredictable() const { return Flags & Unpredictable; }
bool hasInBounds() const { return Flags & InBounds; }
-
bool operator==(const SDNodeFlags &Other) const {
return Flags == Other.Flags;
}
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 78f63b4406eb0..8198485803d8b 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -242,11 +242,15 @@ class LLVM_ABI TargetLoweringBase {
/// Enum that describes what type of support for selects the target has.
enum SelectSupportKind {
- ScalarValSelect, // The target supports scalar selects (ex: cmov).
- ScalarCondVectorVal, // The target supports selects with a scalar condition
- // and vector values (ex: cmov).
- VectorMaskSelect // The target supports vector selects with a vector
- // mask (ex: x86 blends).
+ ScalarValSelect, // The target supports scalar selects (ex: cmov).
+ ScalarCondVectorVal, // The target supports selects with a scalar condition
+ // and vector values (ex: cmov).
+ VectorMaskSelect, // The target supports vector selects with a vector
+ // mask (ex: x86 blends).
+ CtSelect, // The target implements a custom constant-time select.
+ ScalarCondVectorValCtSelect, // The target supports selects with a scalar
+ // condition and vector values.
+ VectorMaskValCtSelect, // The target supports vector selects with a vector
};
/// Enum that specifies what an atomic load/AtomicRMWInst is expanded
@@ -476,8 +480,8 @@ class LLVM_ABI TargetLoweringBase {
MachineMemOperand::Flags
getVPIntrinsicMemOperandFlags(const VPIntrinsic &VPIntrin) const;
- virtual bool isSelectSupported(SelectSupportKind /*kind*/) const {
- return true;
+ virtual bool isSelectSupported(SelectSupportKind kind) const {
+ return kind != CtSelect;
}
/// Return true if the @llvm.get.active.lane.mask intrinsic should be expanded
diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td
index 6a079f62dd9cf..d41c61777089d 100644
--- a/llvm/include/llvm/IR/Intrinsics.td
+++ b/llvm/include/llvm/IR/Intrinsics.td
@@ -1825,6 +1825,15 @@ def int_coro_subfn_addr : DefaultAttrsIntrinsic<
[IntrReadMem, IntrArgMemOnly, ReadOnly<ArgIndex<0>>,
NoCapture<ArgIndex<0>>]>;
+///===-------------------------- Constant Time Intrinsics
+///--------------------------===//
+//
+// Intrinsic to support constant time select
+def int_ct_select
+ : DefaultAttrsIntrinsic<[llvm_any_ty],
+ [llvm_i1_ty, LLVMMatchType<0>, LLVMMatchType<0>],
+ [IntrWriteMem, IntrWillReturn, NoUndef<RetIndex>]>;
+
///===-------------------------- Other Intrinsics --------------------------===//
//
// TODO: We should introduce a new memory kind fo traps (and other side effects
diff --git a/llvm/include/llvm/Target/TargetSelectionDAG.td b/llvm/include/llvm/Target/TargetSelectionDAG.td
index 07a858fd682fc..de4abd713d3cf 100644
--- a/llvm/include/llvm/Target/TargetSelectionDAG.td
+++ b/llvm/include/llvm/Target/TargetSelectionDAG.td
@@ -214,6 +214,11 @@ def SDTSelect : SDTypeProfile<1, 3, [ // select
SDTCisInt<1>, SDTCisSameAs<0, 2>, SDTCisSameAs<2, 3>
]>;
+def SDTCtSelect
+ : SDTypeProfile<1, 3,
+ [ // ctselect
+ SDTCisInt<1>, SDTCisSameAs<0, 2>, SDTCisSameAs<2, 3>]>;
+
def SDTVSelect : SDTypeProfile<1, 3, [ // vselect
SDTCisVec<0>, SDTCisInt<1>, SDTCisSameAs<0, 2>, SDTCisSameAs<2, 3>, SDTCisSameNumEltsAs<0, 1>
]>;
@@ -717,6 +722,7 @@ def reset_fpmode : SDNode<"ISD::RESET_FPMODE", SDTNone, [SDNPHasChain]>;
def setcc : SDNode<"ISD::SETCC" , SDTSetCC>;
def select : SDNode<"ISD::SELECT" , SDTSelect>;
+def ctselect : SDNode<"ISD::CTSELECT", SDTCtSelect>;
def vselect : SDNode<"ISD::VSELECT" , SDTVSelect>;
def selectcc : SDNode<"ISD::SELECT_CC" , SDTSelectCC>;
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 46c4bb85a7420..28fcebbb4a92a 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -484,6 +484,7 @@ namespace {
SDValue visitCTTZ_ZERO_UNDEF(SDNode *N);
SDValue visitCTPOP(SDNode *N);
SDValue visitSELECT(SDNode *N);
+ SDValue visitCTSELECT(SDNode *N);
SDValue visitVSELECT(SDNode *N);
SDValue visitVP_SELECT(SDNode *N);
SDValue visitSELECT_CC(SDNode *N);
@@ -1898,6 +1899,7 @@ void DAGCombiner::Run(CombineLevel AtLevel) {
}
SDValue DAGCombiner::visit(SDNode *N) {
+
// clang-format off
switch (N->getOpcode()) {
default: break;
@@ -1968,6 +1970,7 @@ SDValue DAGCombiner::visit(SDNode *N) {
case ISD::CTTZ_ZERO_UNDEF: return visitCTTZ_ZERO_UNDEF(N);
case ISD::CTPOP: return visitCTPOP(N);
case ISD::SELECT: return visitSELECT(N);
+ case ISD::CTSELECT: return visitCTSELECT(N);
case ISD::VSELECT: return visitVSELECT(N);
case ISD::SELECT_CC: return visitSELECT_CC(N);
case ISD::SETCC: return visitSETCC(N);
@@ -6032,6 +6035,7 @@ static SDValue isSaturatingMinMax(SDValue N0, SDValue N1, SDValue N2,
N0CC = cast<CondCodeSDNode>(N0.getOperand(4))->get();
break;
case ISD::SELECT:
+ case ISD::CTSELECT:
case ISD::VSELECT:
if (N0.getOperand(0).getOpcode() != ISD::SETCC)
return SDValue();
@@ -12184,8 +12188,9 @@ template <class MatchContextClass>
static SDValue foldBoolSelectToLogic(SDNode *N, const SDLoc &DL,
SelectionDAG &DAG) {
assert((N->getOpcode() == ISD::SELECT || N->getOpcode() == ISD::VSELECT ||
- N->getOpcode() == ISD::VP_SELECT) &&
- "Expected a (v)(vp.)select");
+ N->getOpcode() == ISD::VP_SELECT ||
+ N->getOpcode() == ISD::CTSELECT) &&
+ "Expected a (v)(vp.)(ct) select");
SDValue Cond = N->getOperand(0);
SDValue T = N->getOperand(1), F = N->getOperand(2);
EVT VT = N->getValueType(0);
@@ -12547,6 +12552,109 @@ SDValue DAGCombiner::visitSELECT(SDNode *N) {
return SDValue();
}
+SDValue DAGCombiner::visitCTSELECT(SDNode *N) {
+ SDValue N0 = N->getOperand(0);
+ SDValue N1 = N->getOperand(1);
+ SDValue N2 = N->getOperand(2);
+ EVT VT = N->getValueType(0);
+ EVT VT0 = N0.getValueType();
+ SDLoc DL(N);
+ SDNodeFlags Flags = N->getFlags();
+
+ if (SDValue V = foldBoolSelectToLogic<EmptyMatchContext>(N, DL, DAG))
+ return V;
+
+ // ctselect (not Cond), N1, N2 -> ctselect Cond, N2, N1
+ if (SDValue F = extractBooleanFlip(N0, DAG, TLI, false)) {
+ SDValue SelectOp = DAG.getNode(ISD::CTSELECT, DL, VT, F, N2, N1);
+ SelectOp->setFlags(Flags);
+ return SelectOp;
+ }
+
+ if (VT0 == MVT::i1) {
+ // The code in this block deals with the following 2 equivalences:
+ // select(C0|C1, x, y) <=> select(C0, x, select(C1, x, y))
+ // select(C0&C1, x, y) <=> select(C0, select(C1, x, y), y)
+ // The target can specify its preferred form with the
+ // shouldNormalizeToSelectSequence() callback. However we always transform
+ // to the right anyway if we find the inner select exists in the DAG anyway
+ // and we always transform to the left side if we know that we can further
+ // optimize the combination of the conditions.
+ bool normalizeToSequence =
+ TLI.shouldNormalizeToSelectSequence(*DAG.getContext(), VT);
+ // ctselect (and Cond0, Cond1), X, Y
+ // -> ctselect Cond0, (ctselect Cond1, X, Y), Y
+ if (N0->getOpcode() == ISD::AND && N0->hasOneUse()) {
+ SDValue Cond0 = N0->getOperand(0);
+ SDValue Cond1 = N0->getOperand(1);
+ SDValue InnerSelect = DAG.getNode(ISD::CTSELECT, DL, N1.getValueType(),
+ Cond1, N1, N2, Flags);
+ if (normalizeToSequence || !InnerSelect.use_empty())
+ return DAG.getNode(ISD::CTSELECT, DL, N1.getValueType(), Cond0,
+ InnerSelect, N2, Flags);
+ // Cleanup on failure.
+ if (InnerSelect.use_empty())
+ recursivelyDeleteUnusedNodes(InnerSelect.getNode());
+ }
+ // ctselect (or Cond0, Cond1), X, Y -> ctselect Cond0, X, (ctselect Cond1,
+ // X, Y)
+ if (N0->getOpcode() == ISD::OR && N0->hasOneUse()) {
+ SDValue Cond0 = N0->getOperand(0);
+ SDValue Cond1 = N0->getOperand(1);
+ SDValue InnerSelect = DAG.getNode(ISD::CTSELECT, DL, N1.getValueType(),
+ Cond1, N1, N2, Flags);
+ if (normalizeToSequence || !InnerSelect.use_empty())
+ return DAG.getNode(ISD::CTSELECT, DL, N1.getValueType(), Cond0, N1,
+ InnerSelect, Flags);
+ // Cleanup on failure.
+ if (InnerSelect.use_empty())
+ recursivelyDeleteUnusedNodes(InnerSelect.getNode());
+ }
+
+ // ctselect Cond0, (ctselect Cond1, X, Y), Y -> ctselect (and Cond0, Cond1),
+ // X, Y
+ if (N1->getOpcode() == ISD::CTSELECT && N1->hasOneUse()) {
+ SDValue N1_0 = N1->getOperand(0);
+ SDValue N1_1 = N1->getOperand(1);
+ SDValue N1_2 = N1->getOperand(2);
+ if (N1_2 == N2 && N0.getValueType() == N1_0.getValueType()) {
+ // Create the actual and node if we can generate good code for it.
+ if (!normalizeToSequence) {
+ SDValue And = DAG.getNode(ISD::AND, DL, N0.getValueType(), N0, N1_0);
+ return DAG.getNode(ISD::CTSELECT, DL, N1.getValueType(), And, N1_1,
+ N2, Flags);
+ }
+ // Otherwise see if we can optimize the "and" to a better pattern.
+ if (SDValue Combined = visitANDLike(N0, N1_0, N)) {
+ return DAG.getNode(ISD::CTSELECT, DL, N1.getValueType(), Combined,
+ N1_1, N2, Flags);
+ }
+ }
+ }
+ // ctselect Cond0, X, (ctselect Cond1, X, Y) -> ctselect (or Cond0, Cond1),
+ // X, Y
+ if (N2->getOpcode() == ISD::CTSELECT && N2->hasOneUse()) {
+ SDValue N2_0 = N2->getOperand(0);
+ SDValue N2_1 = N2->getOperand(1);
+ SDValue N2_2 = N2->getOperand(2);
+ if (N2_1 == N1 && N0.getValueType() == N2_0.getValueType()) {
+ // Create the actual or node if we can generate good code for it.
+ if (!normalizeToSequence) {
+ SDValue Or = DAG.getNode(ISD::OR, DL, N0.getValueType(), N0, N2_0);
+ return DAG.getNode(ISD::CTSELECT, DL, N1.getValueType(), Or, N1, N2_2,
+ Flags);
+ }
+ // Otherwise see if we can optimize to a better pattern.
+ if (SDValue Combined = visitORLike(N0, N2_0, DL))
+ return DAG.getNode(ISD::CTSELECT, DL, N1.getValueType(), Combined, N1,
+ N2_2, Flags);
+ }
+ }
+ }
+
+ return SDValue();
+}
+
// This function assumes all the vselect's arguments are CONCAT_VECTOR
// nodes and that the condition is a BV of ConstantSDNodes (or undefs).
static SDValue ConvertSelectToConcatVector(SDNode *N, SelectionDAG &DAG) {
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index 431a81002074f..8178fd8981519 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -4136,6 +4136,46 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
}
Results.push_back(Tmp1);
break;
+ case ISD::CTSELECT: {
+ Tmp1 = Node->getOperand(0);
+ Tmp2 = Node->getOperand(1);
+ Tmp3 = Node->getOperand(2);
+ EVT VT = Tmp2.getValueType();
+ if (VT.isVector()) {
+ SmallVector<SDValue> Elements;
+ unsigned NumElements = VT.getVectorNumElements();
+ EVT ScalarVT = VT.getScalarType();
+ for (unsigned Idx = 0; Idx < NumElements; ++Idx) {
+ SDValue IdxVal = DAG.getConstant(Idx, dl, MVT::i64);
+ SDValue TVal =
+ DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, ScalarVT, Tmp2, IdxVal);
+ SDValue FVal =
+ DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, ScalarVT, Tmp3, IdxVal);
+ Elements.push_back(
+ DAG.getCTSelect(dl, ScalarVT, Tmp1, TVal, FVal, Node->getFlags()));
+ }
+ Tmp1 = DAG.getBuildVector(VT, dl, Elements);
+ } else if (VT.isFloatingPoint()) {
+ EVT IntegerVT = EVT::getIntegerVT(*DAG.getContext(), VT.getSizeInBits());
+ Tmp2 = DAG.getBitcast(IntegerVT, Tmp2);
+ Tmp3 = DAG.getBitcast(IntegerVT, Tmp3);
+ Tmp1 = DAG.getBitcast(VT, DAG.getCTSelect(dl, IntegerVT, Tmp1, Tmp2, Tmp3,
+ Node->getFlags()));
+ } else {
+ assert(VT.isInteger());
+ EVT HalfVT = VT.getHalfSizedIntegerVT(*DAG.getContext());
+ auto [Tmp2Lo, Tmp2Hi] = DAG.SplitScalar(Tmp2, dl, HalfVT, HalfVT);
+ auto [Tmp3Lo, Tmp3Hi] = DAG.SplitScalar(Tmp3, dl, HalfVT, HalfVT);
+ SDValue ResLo =
+ DAG.getCTSelect(dl, HalfVT, Tmp1, Tmp2Lo, Tmp3Lo, Node->getFlags());
+ SDValue ResHi =
+ DAG.getCTSelect(dl, HalfVT, Tmp1, Tmp2Hi, Tmp3Hi, Node->getFlags());
+ Tmp1 = DAG.getNode(ISD::BUILD_PAIR, dl, VT, ResLo, ResHi);
+ Tmp1->setFlags(Node->getFlags());
+ }
+ Results.push_back(Tmp1);
+ break;
+ }
case ISD::BR_JT: {
SDValue Chain = Node->getOperand(0);
SDValue Table = Node->getOperand(1);
@@ -5474,7 +5514,8 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
Results.push_back(DAG.getNode(ISD::TRUNCATE, dl, OVT, Tmp2));
break;
}
- case ISD::SELECT: {
+ case ISD::SELECT:
+ case ISD::CTSELECT: {
unsigned ExtOp, TruncOp;
if (Node->getValueType(0).isVector() ||
Node->getValueType(0).getSizeInBits() == NVT.getSizeInBits()) {
@@ -5492,7 +5533,8 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
Tmp2 = DAG.getNode(ExtOp, dl, NVT, Node->getOperand(1));
Tmp3 = DAG.getNode(ExtOp, dl, NVT, Node->getOperand(2));
// Perform the larger operation, then round down.
- Tmp1 = DAG.getSelect(dl, NVT, Tmp1, Tmp2, Tmp3);
+ Tmp1 = DAG.getNode(Node->getOpcode(), dl, NVT, Tmp1, Tmp2, Tmp3);
+ Tmp1->setFlags(Node->getFlags());
if (TruncOp != ISD::FP_ROUND)
Tmp1 = DAG.getNode(TruncOp, dl, Node->getValueType(0), Tmp1);
else
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
index 58983cb57d7f6..855a15a744cfe 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
@@ -159,6 +159,7 @@ void DAGTypeLegalizer::SoftenFloatResult(SDNode *N, unsigned ResNo) {
case ISD::ATOMIC_LOAD: R = SoftenFloatRes_ATOMIC_LOAD(N); break;
case ISD::ATOMIC_SWAP: R = BitcastToInt_ATOMIC_SWAP(N); break;
case ISD::SELECT: R = SoftenFloatRes_SELECT(N); break;
+ case ISD::CTSELECT: R = SoftenFloatRes_CTSELECT(N); break;
case ISD::SELECT_CC: R = SoftenFloatRes_SELECT_CC(N); break;
case ISD::FREEZE: R = SoftenFloatRes_FREEZE(N); break;
case ISD::STRICT_SINT_TO_FP:
@@ -1041,6 +1042,13 @@ SDValue DAGTypeLegalizer::SoftenFloatRes_SELECT(SDNode *N) {
LHS.getValueType(), N->getOperand(0), LHS, RHS);
}
+SDValue DAGTypeLegalizer::SoftenFloatRes_CTSELECT(SDNode *N) {
+ SDValue LHS = GetSoftenedFloat(N->getOperand(1));
+ SDValue RHS = GetSoftenedFloat(N->getOperand(2));
+ return DAG.getCTSelect(SDLoc(N), LHS.getValueType(), N->getOperand(0), LHS,
+ RHS);
+}
+
SDValue DAGTypeLegalizer::SoftenFloatRes_SELECT_CC(SDNode *N) {
SDValue LHS = GetSoftenedFloat(N->getOperand(2));
SDValue RHS = GetSoftenedFloat(N->getOperand(3));
@@ -1561,6 +1569,7 @@ void DAGTypeLegalizer::ExpandFloatResult(SDNode *N, unsigned ResNo) {
case ISD::POISON:
case ISD::UNDEF: SplitRes_UNDEF(N, Lo, Hi); break;
case ISD::SELECT: SplitRes_Select(N, Lo, Hi); break;
+ case ISD::CTSELECT: SplitRes_Select(N, Lo, Hi); break;
case ISD::SELECT_CC: SplitRes_SELECT_CC(N, Lo, Hi); break;
case ISD::MERGE_VALUES: ExpandRes_MERGE_VALUES(N, ResNo, Lo, Hi); break;
@@ -2917,6 +2926,9 @@ void DAGTypeLegalizer::PromoteFloatResult(SDNode *N, unsigned ResNo) {
R = PromoteFloatRes_ATOMIC_LOAD(N);
break;
case ISD::SELECT: R = PromoteFloatRes_SELECT(N); break;
+ case ISD::CTSELECT:
+ R = PromoteFloatRes_SELECT(N);
+ break;
case ISD::SELECT_CC: R = PromoteFloatRes_SELECT_CC(N); break;
case ISD::SINT_TO_FP:
@@ -3219,7 +3231,7 @@ SDValue DAGTypeLegalizer::PromoteFloatRes_SELECT(SDNode *N) {
SDValue TrueVal = GetPromotedFloat(N->getOperand(1));
SDValue FalseVal = GetPromotedFloat(N->getOperand(2));
- return DAG.getNode(ISD::SELECT, SDLoc(N), TrueVal->getValueType(0),
+ return DAG.getNode(N->getOpcode(), SDLoc(N), TrueVal->getValueType(0),
N->getOperand(0), TrueVal, FalseVal);
}
@@ -3403,6 +3415,9 @@ void DAGTypeLegalizer::SoftPromoteHalfResult(SDNode *N, unsigned ResNo) {
R = SoftPromoteHalfRes_ATOMIC_LOAD(N);
break;
case ISD::SELECT: R = SoftPromoteHalfRes_SELECT(N); break;
+ case ISD::CTSELECT:
+ R = SoftPromoteHalfRes_SELECT(N);
+ break;
case ISD::SELECT_CC: R = SoftPromoteHalfRes_SELECT_CC(N); break;
case ISD::STRICT_SINT_TO_FP:
case ISD::STRICT_UINT_TO_FP:
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index 44e5a187c4281..0135b3195438b 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -95,6 +95,7 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
Res = PromoteIntRes_VECTOR_COMPRESS(N);
break;
case ISD::SELECT:
+ case ISD::CTSELECT:
case ISD::VSELECT:
case ISD::VP_SELECT:
case ISD::VP_MERGE:
@@ -2013,6 +2014,9 @@...
[truncated]
|
|
Is there a guarantee in SelectionDAG that bitwise operations will never turned back into a select? |
@topperc Are you asking if there's any guarantee that SelectionDAG won't fold the bitwise operations back into a SELECT node? If so, we use DAG chaining to create artificial dependencies between the instructions. This prevents later optimization Have you seen cases where this might not be sufficient or are there specific optimization passes we should be |
| APInt AllOnesVal = APInt::getAllOnes(BitWidth); | ||
| SDValue ScalarAllOnes = | ||
| DAG.getConstant(AllOnesVal, DL, WorkingVT.getScalarType()); | ||
| AllOnes = DAG.getSplatVector(WorkingVT, DL, ScalarAllOnes); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why are scalable vector special here? Isn't getAllOnesConstant supposed to handle this transparently for all types?
|
|
||
| // Handle floating-point types: bitcast to integer for bitwise operations | ||
| if (VT.isFloatingPoint()) { | ||
| if (VT.isVector()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you use EVT::changeTypeToInteger?
| EVT CondVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1, NumElems); | ||
|
|
||
| if (VT.isScalableVector()) { | ||
| Cond = DAG.getSplatVector(CondVT, DL, Cond); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use DAG.getSplat which already handles the scalable vector difference?
| case Intrinsic::ct_select: { | ||
| // Set function attribute to indicate ct.select usage | ||
| Function &F = DAG.getMachineFunction().getFunction(); | ||
| F.addFnAttr("ct-select"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems a little weird for SelectionDAGBuilder to be modifying the IR function.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is adding an attribute to MachineFunction that will be checked later during instructions bundling. We use instruction bundling in target specific lowering pipeline and will be part of subsequent PR.
| unsigned NumElements = VT.getVectorNumElements(); | ||
| EVT ScalarVT = VT.getScalarType(); | ||
| for (unsigned Idx = 0; Idx < NumElements; ++Idx) { | ||
| SDValue IdxVal = DAG.getConstant(Idx, dl, MVT::i64); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should be using DAG.getVectorIdxConstant. MVT::i64 might not be a legal type.
As @wizardengineer mentioned, the DAG chaining does maintain artificial dependencies between the nodes this preventing the node merging and pattern detection to select. However, the future optimizations may detect the pattern and convert to select. But the this is fallback implementation and we are hoping in future every target will have their specific pipeline for lowering the ctselect intrinsic providing stronger guarantees. |
|
A couple of small nits from the PR description:
A constant-time select also wants to guarantee execution time independent of the values being selected! I guess you didn't mention that because you'd expect it from even the normal way to select between two values. But it wasn't guaranteed by the normal way, so perhaps it's worth including explicitly here anyway. Also it's probably worth mentioning that the "guarantee" is sometimes conditional, e.g. on Arm DIT being set while running the generated code. I expect that for instructions this simple it's probably OK in practice even without that, but that's not the same thing as a promise that the next-generation CPU won't find something clever to do. |
| @@ -783,6 +783,10 @@ enum NodeType { | |||
| /// i1 then the high bits must conform to getBooleanContents. | |||
| SELECT, | |||
|
|
|||
| /// Constant-time Select, implemented with CMOV instruction. This is used to | |||
| /// implement constant-time select. | |||
| CTSELECT, | |||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
implemented with CMOV instruction ? Your description says these will be expanded to bitselect patterns?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also, "Constant-time select [...] This is used to implement constant-time select" isn't really adding any value by repeating the same phrase again 🙂
Perhaps better to restate the order of parameters (it's obvious to you that it's the same as SELECT immediately above, but perhaps not to the next reader), and also, what conditions apply to the boolean – if it's not an i1, is it still expected to be an integer containing 0 or 1, or is it a bitmask containing 0 or ~0, or what?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
implemented with CMOV instruction ? Your description says these will be expanded to bitselect patterns?
Good catch, that's not suppose to be there. Original when we were constructing the constant-time code, we had added implementation for the specific archs (like x86 in this case), the fallback implementation was added later. So I think that's why the comment mentions CMOV. I'll make sure to fix it, thanks! :)
| @@ -435,6 +435,9 @@ struct SDNodeFlags { | |||
| NonNeg | NoNaNs | NoInfs | SameSign | InBounds, | |||
| FastMathFlags = NoNaNs | NoInfs | NoSignedZeros | AllowReciprocal | | |||
| AllowContract | ApproximateFuncs | AllowReassociation, | |||
|
|
|||
| // Flag for disabling optimization | |||
| NoMerge = 1 << 15, | |||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This shouldn't be here, I'll make sure to remove it
| @@ -783,6 +783,10 @@ enum NodeType { | |||
| /// i1 then the high bits must conform to getBooleanContents. | |||
| SELECT, | |||
|
|
|||
| /// Constant-time Select, implemented with CMOV instruction. This is used to | |||
| /// implement constant-time select. | |||
| CTSELECT, | |||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also, "Constant-time select [...] This is used to implement constant-time select" isn't really adding any value by repeating the same phrase again 🙂
Perhaps better to restate the order of parameters (it's obvious to you that it's the same as SELECT immediately above, but perhaps not to the next reader), and also, what conditions apply to the boolean – if it's not an i1, is it still expected to be an integer containing 0 or 1, or is it a bitmask containing 0 or ~0, or what?
| CtSelect, // The target implements a custom constant-time select. | ||
| ScalarCondVectorValCtSelect, // The target supports selects with a scalar | ||
| // condition and vector values. | ||
| VectorMaskValCtSelect, // The target supports vector selects with a vector |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: two of these three don't mention "constant-time" in their comments.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We only rely on CtSelect to decide target lowering support for the CTSELECT dag node. Other two should be removed here. If needed for supporting other targets in future, we can introduce them again.
| virtual bool isSelectSupported(SelectSupportKind /*kind*/) const { | ||
| return true; | ||
| virtual bool isSelectSupported(SelectSupportKind kind) const { | ||
| return kind != CtSelect; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is this not checking for all three of the new values you added? It looks as if will assume by default that every target supports ScalarCondVectorValCtSelect and VectorMaskValCtSelect.
| NoCapture<ArgIndex<0>>]>; | ||
|
|
||
| ///===-------------------------- Constant Time Intrinsics | ||
| ///--------------------------===// |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: your ASCII art has been accidentally word-wrapped 🙂
| SDValue visitCTTZ_ZERO_UNDEF(SDNode *N); | ||
| SDValue visitCTPOP(SDNode *N); | ||
| SDValue visitSELECT(SDNode *N); | ||
| SDValue visitCTSELECT(SDNode *N); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I hate to say it, but is it possible that the naming of this node will cause confusion? Two lines above this visitCTSELECT I see the existing visitCTPOP, which looks very similar, but is totally different – your CT stands for "constant time" but theirs stands for "count".
| N2, Flags); | ||
| } | ||
| // Otherwise see if we can optimize the "and" to a better pattern. | ||
| if (SDValue Combined = visitANDLike(N0, N1_0, N)) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are you confident that this visitANDLike call is going to preserve constant-time semantics? (Not just now, but when someone editing visitANDLike has a clever idea next year. Is that function documented as requiring CT to be preserved?)
It would be a shame to take as input a perfectly safe double-CTSELECT and spit out a thing which had "helpfully" optimized the condition into something that wasn't constant-time any more.
(Same goes for the visitORLike below, of course.)
| // Invert mask for false value | ||
| SDValue Invert = DAG.getNode(ISD::XOR, DL, WorkingVT, Mask, AllOnes); | ||
|
|
||
| // Compute: (T & Mask) | (F & ~Mask) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it better to compute F ^ ((T ^ F) & Mask)? That's what I normally do in my handwritten bit-twiddling selects. It's the same number of binary bitwise operations (two XORs and an AND, instead of two ANDs and an OR), but it avoids the extra unary operation of having to invert Mask.
| CanUseChaining = TLI.isTypeLegal(WorkingVT.getSimpleVT()); | ||
| } else { | ||
| // For scalable vectors, disable chaining (conservative approach) | ||
| CanUseChaining = false; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hang on, what happens to the CT guarantee in this branch? If this chaining plan is the argument for how subsequent optimizations aren't going to undo your good work, shouldn't there at least be a comment here explaining why it's OK to just not do it?
| CtSelect, // The target implements a custom constant-time select. | ||
| ScalarCondVectorValCtSelect, // The target supports selects with a scalar | ||
| // condition and vector values. | ||
| VectorMaskValCtSelect, // The target supports vector selects with a vector |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we don't use VectorMaskValCtSelect anywhere so we can remove it.
| case Intrinsic::ct_select: { | ||
| // Set function attribute to indicate ct.select usage | ||
| Function &F = DAG.getMachineFunction().getFunction(); | ||
| F.addFnAttr("ct-select"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is adding an attribute to MachineFunction that will be checked later during instructions bundling. We use instruction bundling in target specific lowering pipeline and will be part of subsequent PR.
| CtSelect, // The target implements a custom constant-time select. | ||
| ScalarCondVectorValCtSelect, // The target supports selects with a scalar | ||
| // condition and vector values. | ||
| VectorMaskValCtSelect, // The target supports vector selects with a vector |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We only rely on CtSelect to decide target lowering support for the CTSELECT dag node. Other two should be removed here. If needed for supporting other targets in future, we can introduce them again.

Here's the updated PR description with the actual PR links:
This is in reference to our [RFC] Constant-Time Coding Support proposal.
Summary
This PR introduces core infrastructure for constant-time selection operations in LLVM, providing a foundation for
cryptographic code that prevents timing side-channels. This is the first PR in a stacked series implementing
comprehensive constant-time selection support across multiple architectures.
Changes
1. Core Intrinsic Definition
Adds the
llvm.ct.selectintrinsic family to LLVM IR:result = condition ? true_value : false_valuellvm/include/llvm/IR/Intrinsics.td2. Generic Fallback Implementation
Implements architecture-agnostic SelectionDAG lowering using bitwise operations:
result = (true_val & mask) | (false_val & ~mask)wheremask = -(condition)llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cppThe fallback implementation converts boolean conditions into bitmasks and uses bitwise arithmetic to achieve
constant-time selection:
condition (0 or 1) → NEG → mask (0x0000... or 0xFFFF...)
result = (true_val & mask) | (false_val & ~mask)
This approach guarantees:
3. Test Coverage
Includes basic test cases demonstrating fallback functionality:
llvm/test/CodeGen/RISCV/ctselect-fallback.ll- Generic fallback pattern verificationllvm/test/CodeGen/X86/ctselect.ll- Demonstrates generic lowering before optimizationThese tests verify that the intrinsic correctly lowers to constant-time bitwise operations on architectures without
native support.
Architecture Support
This PR provides the fallback implementation that works on all architectures. Subsequent PRs in the stack will add:
__builtin_ct_select)Security Properties
The fallback implementation ensures:
Related PRs
This is part of a stacked PR series implementing constant-time selection:
__builtin_ct_select)Testing
All changes pass existing regression tests. New tests verify:
References
cc @nikic, @dtcxzyw